'''
Answer question Q2: how well can states derived from the status features of the users predict the states after the chosen activities?

Author: Meng Zhang
Date: January 2024
Disclaimer: adapted from the analysis code https://doi.org/10.4121/22153898.v1

Input: RL_trasition_weighted_reward.csv
Output: Figure 3 (saved as Figures/Figure_3.pdf)
'''

import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import scipy.stats
import seaborn as sns

import Utils as util
import Calculate_Q_values as cal_q

### LOAD DATA
# Number of actions; passed as `num_act` when the dynamics are estimated below.
NUM_ACTIONS = 14

# The helper returns four values; the last three are handed on below as the
# reward mean / minimum / maximum arguments.
# NOTE(review): 0.5 is presumably the weight of the weighted reward sum — confirm in Utils.
# NOTE(review): `min` and `max` shadow the Python builtins for the rest of the
# script. The names are kept because later sections read them.
df_weighted, mean, min, max = util.weighted_sum_of_reward__for_transitions(0.5)

# NOTE(review): `eval` executes whatever expression is stored in these CSV cells.
# Only run this on trusted files; ast.literal_eval would be the safe choice if the
# cells are plain list literals.
data = pd.read_csv(
    "RL_trasition_weighted_reward.csv",
    converters={'Binary_State': eval, 'Binary_State_Next_Session': eval},
)

# Unique participant ids in order of first appearance. list(set(...)) yields an
# arbitrary order that can even differ between interpreter runs for string ids,
# which made the progress output and the sample order non-reproducible.
all_people = data['rand_id'].unique().tolist()
NUM_PEOPLE = len(all_people)
print(f"Total number of samples: {len(data)}.")
print(f"Total number of people: {NUM_PEOPLE}.")

#### FEATURE SELECTION
# Number of features that make up an abstract state.
NUM_FEAT_TO_SELECT = 3
# Bounds handed to the effort-to-reward mapping.
OUTPUT_LOWER, OUTPUT_HIGHER = -1, 1
# Features the selection may choose from.
# NOTE(review): presumably indices into `Binary_State` — confirm in Calculate_Q_values.
CANDIDATE_FEATURES = list(range(11))

reward_mean = mean
map_to_rewards = util.get_map_effort_reward(
    reward_mean, OUTPUT_LOWER, OUTPUT_HIGHER, min, max
)

# Work on an independent copy of the loaded frame (copy() is deep by default).
data_train = data.copy()

# The selection sees every column except the participant id and the session number.
df_feat = data_train.drop(['rand_id', 'session_num'], axis="columns")
data_feat = df_feat.values.tolist()
feat_sel = cal_q.feature_selection(
    data_feat, mean, min, max, CANDIDATE_FEATURES, NUM_FEAT_TO_SELECT
)
print("Features selected:", feat_sel, "Weighted_mean", reward_mean)



### Compute the mean likelihood of next states
# We compute the mean likelihood of the observed next states in two ways: based on
# the estimated transition function, and assuming that people stay in their
# current state.
# We use a form of leave-one-out cross-validation: the samples from all but one
# person are used as training data, and the likelihood of the next states is then
# computed for the samples from the left-out person.
discount_factor = 0.85  # NOTE(review): not read anywhere in this script — confirm before removing

# All binary states over the selected features. The position of a state in this
# list is its state index (used to index `trans_func` and the result lists).
abstract_states = [list(i) for i in itertools.product([0, 1], repeat=NUM_FEAT_TO_SELECT)]
num_states = len(abstract_states)  # equals 2**NUM_FEAT_TO_SELECT

# One list of per-sample values for every current state.
prob_per_state = [[] for _ in range(num_states)]  # Likelihood of next state based on transition function
prob_per_state_stay = [[] for _ in range(num_states)]  # Likelihood of next state when assuming that people stay in their current state

for i, person in enumerate(all_people):

    # Use all data except samples from the current person
    data_train_temp = data_train[data_train['rand_id'] != person]
    data_train_temp_samples = data_train_temp[
        ["Binary_State", "Binary_State_Next_Session", "cluster_new_index", "weighted_reward"]
    ].values.tolist()

    # Print status update
    if i % 100 == 0:
        print("...Person " + str(i))

    # Train: only the third return value (the transition function) is needed here.
    _, _, trans_func, _ = cal_q.compute_q_vals_dynamics(data_train_temp_samples,
                                                        reward_mean,
                                                        min,
                                                        max,
                                                        feat_sel, num_act=NUM_ACTIONS)

    # Test on the person that was left out from training.
    # Iterate over this person's rows directly (ordered by session) instead of
    # looking up session numbers 1..n one by one: that lookup raised IndexError
    # whenever a person's session numbers were not exactly 1..n, and it built a
    # new filtered frame for every single sample.
    data_person = data[data['rand_id'] == person].sort_values("session_num", kind="stable")

    for binary_state, binary_next_state, cluster in zip(
        data_person['Binary_State'],
        data_person['Binary_State_Next_Session'],
        data_person['cluster_new_index'],
    ):
        # Reduce the full binary state vectors to the selected features.
        state = list(np.take(binary_state, feat_sel))
        next_state = list(np.take(binary_next_state, feat_sel))
        state_idx = abstract_states.index(state)
        next_state_idx = abstract_states.index(next_state)
        action = int(cluster)

        prob_next_state = trans_func[state_idx, action, next_state_idx]
        prob_per_state[state_idx].append(prob_next_state)
        # Under "stay in state" the observed next state has likelihood 1 if it
        # equals the current state and 0 otherwise.
        prob_per_state_stay[state_idx].append(state == next_state)

# Per-state mean and standard deviation over all left-out samples.
prob_per_state_trans_mean = [np.mean(samples) for samples in prob_per_state]
prob_per_state_trans_std = [np.std(samples) for samples in prob_per_state]
print("\nMean likelihood of next state based on transition function:", np.round(prob_per_state_trans_mean, 2))
print("                                                        SD:", np.round(prob_per_state_trans_std, 2))
prob_per_state_stay_mean = [np.mean(samples) for samples in prob_per_state_stay]
prob_per_state_stay_std = [np.std(samples) for samples in prob_per_state_stay]
print("\nMean likelihood of next state assuming people stay in state:", np.round(prob_per_state_stay_mean, 2))
print("                                                         SD:", np.round(prob_per_state_stay_std, 2))


#### CREATE FIGURE 3
# Grouped bar chart with, per current state, the mean likelihood of the observed
# next state under three models: uniform next states, "stay in state", and the
# estimated transition function.
sns.set_style("white")
med_fontsize = 22
small_fontsize = 18
extrasmall_fontsize = 15
sns.set_context("paper", rc={"font.size":med_fontsize,"axes.titlesize":med_fontsize,"axes.labelsize":med_fontsize,
                            'xtick.labelsize':small_fontsize, 'ytick.labelsize':small_fontsize,
                            'legend.fontsize':extrasmall_fontsize,'legend.title_fontsize': extrasmall_fontsize})

plt.figure(figsize=(10, 5))

# Hatch pattern per bar series, so the series stay distinguishable without colour.
patterns = ["-", "x", ""]

num_states = 2**NUM_FEAT_TO_SELECT
x_vals = np.arange(num_states)
num_bars = 3
width = 1/num_bars - 0.05  # leaves a small gap between neighbouring state groups

# The three bars of a group sit side by side: [x-2w, x-w], [x-w, x], [x, x+w].
plt.bar(x_vals - 0.5 * num_bars * width, np.full(num_states, 1 / num_states),
        width, color='gray', label="Equally likely next states", hatch=patterns[0])
plt.bar(x_vals - 0.5 * width, prob_per_state_stay_mean, width, color='orange',
        label="Stay in state", hatch=patterns[1])
plt.bar(x_vals + 0.5 * width, prob_per_state_trans_mean, width, color='deepskyblue',
        label="Transition function", hatch=patterns[2])

# Add Bayesian confidence intervals (i.e., credible intervals) for the mean
alpha = 0.95
for state in range(num_states):
    for samples, x_pos in (
        (prob_per_state_stay[state], state - 0.5 * width),
        (prob_per_state[state], state + 0.5 * width),
    ):
        # bayes_mvs requires at least two data points; skip the interval for
        # states with too few samples instead of aborting the whole figure.
        if len(samples) < 2:
            continue
        conf = scipy.stats.bayes_mvs(samples, alpha=alpha)[0].minmax
        plt.vlines(x=x_pos, ymin=conf[0], ymax=conf[1], color='black')

plt.ylim([0, 0.85])
plt.ylabel("Mean Likelihood\nof Next State")
plt.xlabel("State")
# Derive the tick labels ("000", "001", ...) from the number of selected features
# so they always match the number of states.
state_labels = ["".join(str(bit) for bit in bits)
                for bits in itertools.product([0, 1], repeat=NUM_FEAT_TO_SELECT)]
plt.xticks(x_vals, state_labels)
plt.legend(loc="upper center")

# savefig does not create missing directories; make sure the target exists so the
# long cross-validation run above is not lost on a missing folder.
os.makedirs("Figures", exist_ok=True)
plt.savefig("Figures/Figure_3.pdf", dpi=1500,
            bbox_inches='tight', pad_inches=0)
